{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "single-view-mpi.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HLgdyo1jgQmk",
        "colab_type": "text"
      },
      "source": [
        "Copyright 2020 Google LLC.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\")."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2XCUxOexgRF7",
        "colab_type": "code",
        "colab": {},
        "cellView": "form"
      },
      "source": [
        "#@title License\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KN8Ud050qB4p",
        "colab_type": "text"
      },
      "source": [
        "# Single image to MPI example Colab\n",
        "\n",
        "This Colab is part of the code for the paper ___Single-view view synthesis with multiplane images___; the code may be found at <br>https://github.com/google-research/google-research/tree/master/single_view_mpi.\n",
        "\n",
        "The project site is at https://single-view-mpi.github.io/.\n",
        "\n",
        "Choose __Run all__ from the Runtime menu to:\n",
        "* set up the network and load our trained model,\n",
        "* apply it to an RGB input to generate a 32-layer MPI,\n",
        "* show individual MPI layers and synthesized disparity,\n",
        "* render novel views from different camera positions,\n",
        "* visualize the resulting MPI in an HTML-based viewer.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EpDKjdWWk7zi",
        "colab_type": "text"
      },
      "source": [
        "## Download library code, model weights, and example image."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eKire3Obk2ra",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# GitHub no longer serves repositories over Subversion, so the library code is\n",
        "# fetched with a shallow, sparse git clone of just the one directory we need.\n",
        "!echo Fetching code from github...\n",
        "!rm -rf google-research-tmp\n",
        "!git clone --quiet --depth 1 --filter=blob:none --sparse https://github.com/google-research/google-research.git google-research-tmp\n",
        "!git -C google-research-tmp sparse-checkout set single_view_mpi\n",
        "!rm -rf single_view_mpi\n",
        "!mv google-research-tmp/single_view_mpi single_view_mpi\n",
        "!rm -rf google-research-tmp\n",
        "\n",
        "# Remove any previous download first (-f: no error on a fresh runtime). The\n",
        "# weights are later loaded from single_view_mpi_full_keras/, so clear that\n",
        "# directory as well as the older single_view_mpi_full name.\n",
        "!echo\n",
        "!echo Fetching trained model weights...\n",
        "!rm -f single_view_mpi_full_keras.tar.gz\n",
        "!rm -rf single_view_mpi_full single_view_mpi_full_keras\n",
        "!wget https://storage.googleapis.com/stereo-magnification-public-files/models/single_view_mpi_full_keras.tar.gz\n",
        "!tar -xzvf single_view_mpi_full_keras.tar.gz\n",
        "\n",
        "!echo\n",
        "!echo Fetching example image...\n",
        "!rm -f input.png\n",
        "!wget https://single-view-mpi.github.io/mpi/7/input.png\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DLxv0DvYqGhy",
        "colab_type": "text"
      },
      "source": [
        "## Set up the model\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "saA4-roFkvsA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Use the %pip magic explicitly (rather than relying on automagic) so the\n",
        "# packages are installed into the environment of the running kernel.\n",
        "%pip install -r single_view_mpi/requirements.txt"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VlX8AfkNzHfR",
        "colab_type": "code",
        "cellView": "both",
        "colab": {}
      },
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "from single_view_mpi.libs import mpi\n",
        "from single_view_mpi.libs import nets\n",
        "\n",
        "# Symbolic input: a batch of 3-channel images of any height and width.\n",
        "# Named image_input (not `input`) so Python's builtin input() is not shadowed\n",
        "# for the rest of the notebook session.\n",
        "image_input = tf.keras.Input(shape=(None, None, 3))\n",
        "mpi_output = nets.mpi_from_image(image_input)\n",
        "\n",
        "model = tf.keras.Model(inputs=image_input, outputs=mpi_output)\n",
        "print('Model created.')\n",
        "# Our full model, trained on RealEstate10K.\n",
        "model.load_weights('single_view_mpi_full_keras/single_view_mpi_keras_weights')\n",
        "print('Weights loaded.')\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_F_UJLZHqunU",
        "colab_type": "text"
      },
      "source": [
        "## Generate an MPI from an input image, show layers and disparity"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "f_anlB3jqwm2",
        "colab_type": "code",
        "cellView": "both",
        "colab": {}
      },
      "source": [
        "import matplotlib.pyplot as plt\n",
        "plt.rcParams[\"figure.figsize\"] = (20, 10)\n",
        "\n",
        "# Input image\n",
        "inputfile = 'input.png'\n",
        "# channels=3 drops any alpha channel (and expands grayscale), so the decoded\n",
        "# image always matches the 3-channel input the model was built with.\n",
        "input_rgb = tf.image.decode_image(\n",
        "    tf.io.read_file(inputfile), channels=3, dtype=tf.float32)\n",
        "\n",
        "# Generate MPI\n",
        "layers = model(input_rgb[tf.newaxis])[0]\n",
        "\n",
        "# Layers is now a tensor of shape [L, H, W, 4].\n",
        "# This represents an MPI with L layers, each of height H and width W, and\n",
        "# each with an RGB+Alpha 4-channel image.\n",
        "\n",
        "# Take the layer count from the tensor itself instead of repeating a literal.\n",
        "num_layers = layers.shape[0]\n",
        "depths = mpi.make_depths(1.0, 100.0, num_layers).numpy()\n",
        "\n",
        "# Depths is a tensor of shape [L] which gives the depths of the L layers.\n",
        "\n",
        "# Display layer images in a grid, eight to a row.\n",
        "columns = 8\n",
        "rows = (num_layers + columns - 1) // columns  # Ceiling division.\n",
        "for i in range(num_layers):\n",
        "  plt.subplot(rows, columns, i+1)\n",
        "  plt.imshow(layers[i])\n",
        "  plt.axis('off')\n",
        "  plt.title('Layer %d' % i, loc='left')\n",
        "plt.show()\n",
        "\n",
        "# Display computed disparity\n",
        "disparity = mpi.disparity_from_layers(layers, depths)\n",
        "plt.imshow(disparity[..., 0])\n",
        "plt.axis('off')\n",
        "plt.title('Synthesized disparity')\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D4r2c9qAbI6H",
        "colab_type": "text"
      },
      "source": [
        "## Generate new views from nearby camera positions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "M-gFmGpGbIGO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# The reference camera uses the identity pose: no rotation, no translation.\n",
        "reference_pose = tf.constant([\n",
        "    [1.0, 0.0, 0.0, 0.0],\n",
        "    [0.0, 1.0, 0.0, 0.0],\n",
        "    [0.0, 0.0, 1.0, 0.0],\n",
        "])\n",
        "\n",
        "# Precise intrinsics would only matter when comparing against a ground truth\n",
        "# image. These describe a 16:9 image whose principal point is at the center.\n",
        "intrinsics = tf.constant([1.0, 1.0 * 16/9, 0.5, 0.5])\n",
        "\n",
        "\n",
        "def render(xoffset, yoffset, zoffset):\n",
        "  \"\"\"Renders the MPI for a target camera offset from the reference view.\n",
        "\n",
        "  Uses the notebook-level `layers`, `depths`, `reference_pose` and\n",
        "  `intrinsics`.\n",
        "\n",
        "  Args:\n",
        "    xoffset: X offset; its negation is placed in the pose's last column.\n",
        "    yoffset: Y offset, handled the same way.\n",
        "    zoffset: Z offset, handled the same way.\n",
        "\n",
        "  Returns:\n",
        "    The result of mpi.render for a 512x910 (height x width) output.\n",
        "  \"\"\"\n",
        "  rotation = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]\n",
        "  translation = [-xoffset, -yoffset, -zoffset]\n",
        "  # The translation is the final column of the pose matrix.\n",
        "  target_pose = tf.constant(\n",
        "      [row + [shift] for row, shift in zip(rotation, translation)])\n",
        "  return mpi.render(\n",
        "      layers, depths,\n",
        "      reference_pose, intrinsics,  # Reference view\n",
        "      target_pose, intrinsics,  # Target view\n",
        "      height=512, width=910)\n",
        "\n",
        "\n",
        "# Sweep the camera along the X axis (left to right):\n",
        "for panel, step in enumerate(range(-2, 3), start=1):\n",
        "  xoffset = step * 0.05\n",
        "  plt.subplot(1, 5, panel)\n",
        "  plt.imshow(render(xoffset, 0.0, 0.0))\n",
        "  plt.title('xoff = %f' % xoffset)\n",
        "  plt.axis('off')\n",
        "plt.show()\n",
        "\n",
        "# Then sweep along the Z axis (moving forwards):\n",
        "for panel, step in enumerate(range(-2, 3), start=1):\n",
        "  zoffset = step * 0.15\n",
        "  plt.subplot(1, 5, panel)\n",
        "  plt.imshow(render(0.0, 0.0, zoffset))\n",
        "  plt.title('zoff = %f' % zoffset)\n",
        "  plt.axis('off')\n",
        "plt.show()\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gz8XgxWcqau-",
        "colab_type": "text"
      },
      "source": [
        "## A simple MPI-viewer using HTML + CSS transforms"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8tKOgm4wqcKg",
        "colab_type": "code",
        "cellView": "both",
        "colab": {}
      },
      "source": [
        "import base64\n",
        "\n",
        "def imgurl(image):\n",
        "  \"\"\"Encodes an image as a base64 PNG data URL.\n",
        "\n",
        "  Args:\n",
        "    image: Image tensor accepted by tf.image.resize; presumably a float\n",
        "      [H, W, C] MPI layer with values in [0, 1] (TODO confirm with callers).\n",
        "\n",
        "  Returns:\n",
        "    A string of the form 'data:image/png;base64,...' holding the image\n",
        "    resized to 288 rows by 512 columns.\n",
        "  \"\"\"\n",
        "  # We resize layers to 512x288 so the whole stack can be serialized in a\n",
        "  # Colab for the HTML viewer without hitting the memory restriction. Outside\n",
        "  # Colab there is no such restriction and 512x512 layers could be used.\n",
        "  image = tf.image.resize(image, [288, 512], method='area')\n",
        "  # saturate=True clips out-of-range values before the cast, so they cannot\n",
        "  # overflow and wrap around in the uint8 conversion.\n",
        "  data = tf.image.encode_png(\n",
        "      tf.image.convert_image_dtype(image, tf.uint8, saturate=True)).numpy()\n",
        "  dataurl = 'data:image/png;base64,{}'.format(base64.b64encode(data).decode())\n",
        "  return dataurl\n",
        "\n",
        "def html_viewer(layers, depths):\n",
        "  \"\"\"Builds a self-contained HTML/CSS/JS viewer for an MPI.\n",
        "\n",
        "  Each layer becomes a <div> whose background is the layer image (embedded as\n",
        "  a data URL) and whose CSS transform scales it by its depth and pushes it\n",
        "  back along Z by the same amount.\n",
        "\n",
        "  Args:\n",
        "    layers: Indexable collection of layer images; layers[i] is passed to\n",
        "      imgurl.\n",
        "    depths: Sequence of per-layer depths, one per layer, in the same order.\n",
        "\n",
        "  Returns:\n",
        "    A single HTML string containing the styles, layer markup and script.\n",
        "  \"\"\"\n",
        "  html = []\n",
        "\n",
        "  html.append('''\n",
        "  <style>\n",
        "  #view {\n",
        "    position: relative;\n",
        "    overflow: hidden;\n",
        "    border: 20px solid black;\n",
        "    width: 512px;\n",
        "    height: 288px;\n",
        "    perspective: 500px;\n",
        "    background: #888;\n",
        "  }\n",
        "  #mpi {\n",
        "    transform-style: preserve-3d; -webkit-transform-style: preserve-3d;\n",
        "    height: 100%;\n",
        "    width: 100%;\n",
        "    pointer-events: none;\n",
        "  }\n",
        "  .layer {\n",
        "    position: absolute;\n",
        "    background-size: 100% 100%;\n",
        "    background-repeat: no-repeat;\n",
        "    background-position: center;\n",
        "    width: 100%;\n",
        "    height: 100%;\n",
        "  }\n",
        "  </style>''')\n",
        "  html.append('<h1>MPI Viewer</h1>Hover over the image to control the view.')\n",
        "  html.append('<div id=view><div id=mpi>')\n",
        "  for i, depth in enumerate(depths):\n",
        "    url = imgurl(layers[i])\n",
        "    html.append('''\n",
        "        <div class=layer\n",
        "             style=\"transform: scale(%.3f) translateZ(-%.3fpx);\n",
        "             background-image: url(%s)\"></div>''' % (depth, depth, url))\n",
        "\n",
        "  html.append('</div></div>')\n",
        "  html.append('''\n",
        "  <script>\n",
        "  // Everything lives inside a function so that no names are added to the\n",
        "  // page's global scope.\n",
        "  (function() {\n",
        "    const view = document.querySelector('#view');\n",
        "    const mpi = document.querySelector('#mpi');\n",
        "\n",
        "    function setView(mpi, x, y) {\n",
        "      x = 2*x - 1;\n",
        "      y = 2*y - 1;\n",
        "      const rx = (-1.5 * y).toFixed(2);\n",
        "      const ry = (2.0 * x).toFixed(2);\n",
        "      // Put whatever CSS transform you want in here.\n",
        "      mpi.style.transform =\n",
        "          `rotateX(${rx}deg) rotateY(${ry}deg) translateZ(500px) scaleZ(500)`;\n",
        "    }\n",
        "    setView(mpi, 0.5, 0.5);\n",
        "\n",
        "    // View animates by itself, or you can hover over the image to control it.\n",
        "    let t = 0;\n",
        "    let animate = true;\n",
        "    let pending = null;  // Handle of the queued animation frame, if any.\n",
        "    function tick() {\n",
        "      pending = null;\n",
        "      if (!animate) {\n",
        "        return;\n",
        "      }\n",
        "      t = (t + 1) % 300;\n",
        "      const r = Math.PI * 2 * t / 300;\n",
        "      setView(mpi, 0.5 + 0.3 * Math.cos(r), 0.5 + 0.3 * Math.sin(r));\n",
        "      pending = requestAnimationFrame(tick);\n",
        "    }\n",
        "    tick();\n",
        "\n",
        "    view.addEventListener('mousemove',\n",
        "      (e) => {animate=false; setView(mpi, e.offsetX/view.offsetWidth, e.offsetY/view.offsetHeight);});\n",
        "    view.addEventListener('mouseleave',\n",
        "      (e) => {\n",
        "        animate = true;\n",
        "        // Restart only when no frame is still queued; otherwise two loops\n",
        "        // would run at once and the animation would speed up.\n",
        "        if (pending === null) {\n",
        "          tick();\n",
        "        }\n",
        "      });\n",
        "  })();\n",
        "  </script>\n",
        "  ''')\n",
        "  return ''.join(html)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5CksRfJBP0Et",
        "colab_type": "text"
      },
      "source": [
        "## View the MPI in a live 3D web viewer"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7vbB3JcaqirL",
        "colab_type": "code",
        "cellView": "both",
        "colab": {}
      },
      "source": [
        "from IPython.display import display, HTML\n",
        "\n",
        "display(HTML(html_viewer(layers, depths)))"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}